"""Tool management for the agent core runtime."""

import json
import logging
import os
from typing import Any

import boto3
from mcp import StdioServerParameters, stdio_client
from strands import tool
from strands.tools.mcp import MCPClient

from .config import WORKSPACE_DIR, get_aws_credentials, get_uv_environment

# Import strands-agents code interpreter tool
try:
    from strands_tools.code_interpreter import AgentCoreCodeInterpreter

    CODE_INTERPRETER_AVAILABLE = True
except ImportError as e:
    CODE_INTERPRETER_AVAILABLE = False
    logger = logging.getLogger(__name__)
    logger.warning(f"Strands code interpreter tool not available: {e}")
    AgentCoreCodeInterpreter = None

logger = logging.getLogger(__name__)


class ToolManager:
    """Manages tools including MCP tools and built-in tools."""

    def __init__(self):
        self.mcp_tools = None
        self.session_id = None
        self.trace_id = None

    def set_session_info(self, session_id: str, trace_id: str):
        """Set session and trace IDs for tool operations"""
        self.session_id = session_id
        self.trace_id = trace_id

    def load_mcp_tools(self) -> list[Any]:
        """Load MCP tools from environment variable or mcp.json file"""
        if self.mcp_tools is not None:
            return self.mcp_tools

        try:
            # First try to load from environment variable
            mcp_servers_env = os.environ.get("MCP_SERVERS")
            if mcp_servers_env:
                logger.info("Loading MCP configuration from environment variable")
                mcp_servers = json.loads(mcp_servers_env)
            else:
                # Fallback to mcp.json file
                logger.info("Loading MCP configuration from mcp.json file")
                mcp_config_path = "mcp.json"
                if not os.path.exists(mcp_config_path):
                    logger.warning(f"MCP configuration file not found at {mcp_config_path}")
                    self.mcp_tools = []
                    return self.mcp_tools

                with open(mcp_config_path) as f:
                    mcp_config = json.load(f)
                mcp_servers = mcp_config.get("mcpServers", {})

            mcp_clients = []
            uv_env = get_uv_environment()

            for server_name, server in mcp_servers.items():
                try:
                    client = MCPClient(
                        lambda server=server: stdio_client(
                            StdioServerParameters(
                                command=server["command"],
                                args=server.get("args", []),
                                env={**uv_env, **server.get("env", {})},
                            )
                        )
                    )
                    client.start()
                    mcp_clients.append(client)
                except Exception as e:
                    logger.error(f"Error creating MCP client for {server_name}: {e}")

            # Flatten the tools
            self.mcp_tools = sum([c.list_tools_sync() for c in mcp_clients], [])
            logger.info(f"Loaded {len(self.mcp_tools)} MCP tools")
            return self.mcp_tools

        except Exception as e:
            logger.error(f"Error loading MCP tools: {e}")
            self.mcp_tools = []
            return self.mcp_tools

    def load_mcp_tools_by_names(self, server_names: list[str]) -> list[Any]:
        """Load MCP tools from environment variable or mcp.json by server names"""
        if not server_names:
            return []

        try:
            # First try to load from environment variable
            mcp_servers_env = os.environ.get("MCP_SERVERS")
            if mcp_servers_env:
                logger.info("Loading MCP configuration from environment variable")
                available_servers = json.loads(mcp_servers_env)
            else:
                # Fallback to mcp.json file
                logger.info("Loading MCP configuration from mcp.json file")
                mcp_config_path = "mcp.json"
                logger.info(f"Loading MCP configuration from: {mcp_config_path}")
                with open(mcp_config_path) as f:
                    mcp_config = json.load(f)
                available_servers = mcp_config.get("mcpServers", {})

            logger.info(f"Found {len(available_servers)} available MCP servers")
            mcp_clients = []
            uv_env = get_uv_environment()

            for server_name in server_names:
                if server_name not in available_servers:
                    logger.warning(f"MCP server '{server_name}' not found in configuration")
                    continue

                server_config = available_servers[server_name]
                try:
                    client = MCPClient(
                        lambda server=server_config: stdio_client(
                            StdioServerParameters(
                                command=server["command"],
                                args=server.get("args", []),
                                env={**uv_env, **server.get("env", {})},
                            )
                        )
                    )
                    client.start()
                    mcp_clients.append(client)
                    logger.info(f"Successfully loaded MCP server: {server_name}")
                except Exception as e:
                    logger.error(f"Error creating MCP client for {server_name}: {e}")

            # Flatten the tools
            dynamic_tools = sum([c.list_tools_sync() for c in mcp_clients], [])
            logger.info(f"Loaded {len(dynamic_tools)} MCP tools from {len(mcp_clients)} servers")
            return dynamic_tools

        except Exception as e:
            logger.error(f"Error loading MCP tools by names: {e}")
            return []

    def get_upload_tool(self):
        """Get the S3 upload tool with session context"""
        trace_id = self.trace_id

        @tool
        def upload_file_to_s3_and_retrieve_s3_url(filepath: str) -> str:
            """Upload the file at /tmp/ws/* and retrieve the s3 path

            Args:
                filepath: The path to the uploading file
            """
            bucket = os.environ.get("FILE_BUCKET")
            if not bucket:
                # For local testing, provide a fallback message
                logger.warning("FILE_BUCKET environment variable not set. Using local file path for testing.")
                return f"Local file path (S3 upload skipped): {filepath}"

            aws_creds = get_aws_credentials()
            region = aws_creds.get("AWS_REGION", "us-east-1")

            if not filepath.startswith(WORKSPACE_DIR):
                raise ValueError(f"{filepath} does not appear to be a file under the {WORKSPACE_DIR} directory. Files to be uploaded must exist under {WORKSPACE_DIR}.")

            try:
                filename = os.path.basename(filepath)
                key = f"agentcore/{trace_id}/{filename}"

                s3 = boto3.client("s3", region_name=region)
                s3.upload_file(filepath, bucket, key)

                return f"https://{bucket}.s3.{region}.amazonaws.com/{key}"
            except Exception as e:
                logger.error(f"Error uploading file to S3: {e}")
                # For local testing, provide a fallback
                return f"Error uploading to S3: {str(e)}. Local file path: {filepath}"

        return upload_file_to_s3_and_retrieve_s3_url

    def get_code_interpreter_tool(self) -> list[Any]:
        """Get code interpreter tool if available"""
        code_interpreter_tools = []

        if CODE_INTERPRETER_AVAILABLE and AgentCoreCodeInterpreter:
            try:
                aws_creds = get_aws_credentials()
                region = aws_creds.get("AWS_REGION", "us-east-1")
                code_interpreter = AgentCoreCodeInterpreter(region=region)
                code_interpreter_tools.append(code_interpreter.code_interpreter)
                logger.info("Added code_interpreter tool (AgentCoreCodeInterpreter)")
            except Exception as e:
                logger.warning(f"Failed to initialize AgentCoreCodeInterpreter: {e}")

        return code_interpreter_tools

    def get_tools_with_options(self, code_execution_enabled: bool = False, mcp_servers=None) -> list[Any]:
        """
        Get tools with optional code execution and MCP servers.

        Args:
            code_execution_enabled: Whether to include code interpreter tools
            mcp_servers: MCP server configurations
                - None: Load default MCP servers from mcp.json
                - []: Empty list, no MCP servers (File Upload only)
                - [...]: Load specified MCP servers

        Returns:
            List of all available tools
        """
        logger.info(f"get_tools_with_options called with code_execution_enabled={code_execution_enabled}")
        logger.info(f"mcp_servers parameter: {mcp_servers} (type: {type(mcp_servers)})")

        all_tools = []

        # Handle MCP servers based on parameter
        if mcp_servers is None:
            # Load default MCP servers from mcp.json
            logger.info("Loading default MCP servers from mcp.json")
            mcp_tools = self.load_mcp_tools()
        elif isinstance(mcp_servers, list) and len(mcp_servers) == 0:
            # Empty list: no MCP servers
            logger.info("Empty MCP servers list provided, skipping MCP tools")
            mcp_tools = []
        elif isinstance(mcp_servers, list):
            # Load specified MCP servers by name
            logger.info(f"Loading {len(mcp_servers)} user-specified MCP servers by name")
            mcp_tools = self.load_mcp_tools_by_names(mcp_servers)
        else:
            # Fallback to default
            logger.warning(f"Unexpected mcp_servers type: {type(mcp_servers)}, using default")
            mcp_tools = self.load_mcp_tools()

        all_tools.extend(mcp_tools)

        # Add built-in tools (always included)
        upload_tool = self.get_upload_tool()
        all_tools.append(upload_tool)

        # Add code interpreter tools if enabled
        code_interpreter_tools = []
        if code_execution_enabled:
            code_interpreter_tools = self.get_code_interpreter_tool()
            all_tools.extend(code_interpreter_tools)

        # Log final tool count
        logger.info(f"Total tools loaded: {len(all_tools)} (MCP: {len(mcp_tools)}, Built-in: 1, Code Interpreter: {len(code_interpreter_tools)} - {'enabled' if code_execution_enabled else 'disabled'})")

        return all_tools
